/*
 * Copyright (c) 2020-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS,  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */
#include <tests/iterator/pair_iterator_test.cuh>

#include <cudf_test/random.hpp>

#include <rmm/exec_policy.hpp>

#include <thrust/iterator/transform_iterator.h>
#include <thrust/pair.h>
#include <thrust/reduce.h>

using TestingTypes = cudf::test::NumericTypes;

namespace cudf {
// To print meanvar for debug.
// Needs to be in the cudf namespace for ADL
template <typename T>
std::ostream& operator<<(std::ostream& os, cudf::meanvar<T> const& rhs)
{
  return os << "[" << rhs.value << ", " << rhs.value_squared << ", " << rhs.count << "] ";
};
}  // namespace cudf

template <typename T>
struct NumericPairIteratorTest : public IteratorTest<T> {};

TYPED_TEST_SUITE(NumericPairIteratorTest, TestingTypes);
TYPED_TEST(NumericPairIteratorTest, nonull_pair_iterator) { nonull_pair_iterator(*this); }
TYPED_TEST(NumericPairIteratorTest, null_pair_iterator) { null_pair_iterator(*this); }

// Transformers and Operators for pair_iterator test
template <typename ElementType>
struct transformer_pair_meanvar {
  using ResultType = thrust::pair<cudf::meanvar<ElementType>, bool>;

  CUDF_HOST_DEVICE inline ResultType operator()(thrust::pair<ElementType, bool> const& pair)
  {
    ElementType v = pair.first;
    return {{v, static_cast<ElementType>(v * v), (pair.second) ? 1 : 0}, pair.second};
  };
};

struct sum_if_not_null {
  template <typename T>
  CUDF_HOST_DEVICE inline thrust::pair<T, bool> operator()(thrust::pair<T, bool> const& lhs,
                                                           thrust::pair<T, bool> const& rhs)
  {
    if (lhs.second & rhs.second)
      return {lhs.first + rhs.first, true};
    else if (lhs.second)
      return {lhs};
    else
      return {rhs};
  }
};

// TODO: enable this test also at __CUDACC_DEBUG__
// This test causes fatal compilation error only at device debug mode.
// Workaround: exclude this test only at device debug mode.
#if !defined(__CUDACC_DEBUG__)
// This test computes `count`, `sum`, `sum_of_squares` at a single reduction call.
// It would be useful for `var`, `std` operation
TYPED_TEST(NumericPairIteratorTest, mean_var_output)
{
  using T        = TypeParam;
  using T_output = cudf::meanvar<T>;
  transformer_pair_meanvar<T> transformer{};

  int const column_size{5000};
  const T init{0};

  // data and valid arrays
  std::vector<T> host_values(column_size);
  std::vector<bool> host_bools(column_size);

  if constexpr (std::is_floating_point<T>()) {
    cudf::test::UniformRandomGenerator<int32_t> rng;
    std::generate(host_values.begin(), host_values.end(), [&rng]() {
      return static_cast<T>(rng.generate() % 10);  // reduces float-op errors
    });
  } else {
    cudf::test::UniformRandomGenerator<T> rng;
    std::generate(host_values.begin(), host_values.end(), [&rng]() { return rng.generate(); });
  }

  cudf::test::UniformRandomGenerator<bool> rbg;
  std::generate(host_bools.begin(), host_bools.end(), [&rbg]() { return rbg.generate(); });

  cudf::test::fixed_width_column_wrapper<TypeParam> w_col(
    host_values.begin(), host_values.end(), host_bools.begin());
  auto d_col = cudf::column_device_view::create(w_col);

  // calculate expected values by CPU
  T_output expected_value;

  expected_value.count = d_col->size() - static_cast<cudf::column_view>(w_col).null_count();

  std::vector<T> replaced_array(d_col->size());
  std::transform(host_values.begin(),
                 host_values.end(),
                 host_bools.begin(),
                 replaced_array.begin(),
                 [&](T x, bool b) { return (b) ? static_cast<T>(x) : init; });

  expected_value.count = d_col->size() - static_cast<cudf::column_view>(w_col).null_count();
  expected_value.value = std::accumulate(replaced_array.begin(), replaced_array.end(), T{0});
  expected_value.value_squared = std::accumulate(
    replaced_array.begin(), replaced_array.end(), T{0}, [](T acc, T i) { return acc + i * i; });

  // GPU test
  auto it_dev         = d_col->pair_begin<T, true>();
  auto it_dev_squared = thrust::make_transform_iterator(it_dev, transformer);
  auto result         = thrust::reduce(rmm::exec_policy(cudf::get_default_stream()),
                               it_dev_squared,
                               it_dev_squared + d_col->size(),
                               thrust::make_pair(T_output{}, true),
                               sum_if_not_null{});
  if constexpr (not std::is_floating_point<T>()) {
    EXPECT_EQ(expected_value, result.first) << "pair iterator reduction sum";
  } else {
    EXPECT_NEAR(expected_value.value, result.first.value, 1e-3) << "pair iterator reduction sum";
    EXPECT_NEAR(expected_value.value_squared, result.first.value_squared, 1e-3)
      << "pair iterator reduction sum squared";
    EXPECT_EQ(expected_value.count, result.first.count) << "pair iterator reduction count";
  }
}
#endif
